package gpgsql

import (
	"context"
	"fmt"
	"html/template"
	"strconv"
	"strings"
)

type DBOrm struct {
	Parent *DBGroup
}

func newDBOrm(parent *DBGroup) (dbOrm *DBOrm, err error) {
	dbOrm = &DBOrm{
		Parent: parent,
	}
	return
}

func (dbOrm *DBOrm) Escape(sql string) string {
	return template.HTMLEscapeString(sql)
}

func (dbOrm *DBOrm) buildWhereCond(bindParams []interface{}, whereConds map[string]interface{}) (whereCondsStr string, respBindParams []interface{}) {
	whereCondsStack := []string{}
	for k, v := range whereConds {
		k = strings.TrimSpace(k)
		k = strings.ToLower(k)
		if strings.HasSuffix(k, "in") {
			list, ok := v.([]interface{})
			if ok {
				valueArray := []string{}
				for _, ele := range list {
					bindParams = append(bindParams, ele)
					valueArray = append(valueArray, "?")
				}
				valueStr := strings.Join(valueArray, ",")
				whereCondsStack = append(whereCondsStack, fmt.Sprintf("%s (%s)", k, valueStr))
			}
		} else if v == nil && (strings.HasSuffix(k, "is") || strings.HasSuffix(k, "is not")) {
			whereCondsStack = append(whereCondsStack, fmt.Sprintf("%s NULL", k))
		} else {
			bindParams = append(bindParams, v)
			whereCondsStack = append(whereCondsStack, fmt.Sprintf("%s ?", k))
		}
	}
	whereCondsStr = strings.Join(whereCondsStack, " AND ")
	if len(whereCondsStr) > 0 {
		whereCondsStr = " WHERE " + whereCondsStr
	}
	respBindParams = bindParams
	return
}

func (dbOrm *DBOrm) buildUpdateCond(bindParams []interface{}, updateMap map[string]interface{}) (updateStr string, respBindParams []interface{}) {
	updateStack := []string{}
	for k, v := range updateMap {
		k = dbOrm.Escape(k)
		k = strings.ToLower(k)
		k = strings.TrimSpace(k)
		if strings.HasSuffix(k, "+") {
			k = strings.TrimRight(k, "+")
			bindParams = append(bindParams, v)
			updateStack = append(updateStack, fmt.Sprintf("%s=%s+?", k, k))
		} else if strings.HasSuffix(k, "-") {
			k = strings.TrimRight(k, "-")
			bindParams = append(bindParams, v)
			updateStack = append(updateStack, fmt.Sprintf("%s=%s-?", k, k))
		} else if v == "NOW()" {
			updateStack = append(updateStack, fmt.Sprintf("%s=NOW()", k))
		} else if v == nil {
			updateStack = append(updateStack, fmt.Sprintf("%s=NULL", k))
		} else {
			bindParams = append(bindParams, v)
			updateStack = append(updateStack, fmt.Sprintf("%s=?", k))
		}
	}
	updateStr = strings.Join(updateStack, ",")

	respBindParams = bindParams
	return updateStr, respBindParams
}

// 删除
func (dbOrm *DBOrm) Delete(context context.Context, table string, whereConds map[string]interface{}, dbOptions *DBOptions) (result int64, err error) {
	if dbOptions == nil {
		dbOptions = DEFAULT_DBOPTION
	}
	if len(table) <= 0 || whereConds == nil {
		err = ERR_INVALID_PARAM
		return
	}
	bindParams := []interface{}{}
	var whereCondsStr string
	whereCondsStr, bindParams = dbOrm.buildWhereCond(bindParams, whereConds)

	sql := fmt.Sprintf("DELETE FROM %s %s", table, whereCondsStr)
	result, err = dbOrm.Parent.Exec(context, sql, bindParams...)
	return
}

// 插入
func (dbOrm *DBOrm) Insert(context context.Context, table string, insertMap map[string]interface{}, dbOptions *DBOptions) (result int64, err error) {
	if dbOptions == nil {
		dbOptions = DEFAULT_DBOPTION
	}
	if len(table) <= 0 || insertMap == nil || len(insertMap) <= 0 {
		err = ERR_INVALID_PARAM
		return
	}
	insertKeyStack := []string{}
	insertValueStack := []string{}
	bindParams := []interface{}{}
	for k, v := range insertMap {
		k = dbOrm.Escape(k)
		if "NOW()" == v {
			insertKeyStack = append(insertKeyStack, k)
			insertValueStack = append(insertValueStack, "NOW()")
		} else if v == nil {
			insertKeyStack = append(insertKeyStack, k)
			insertValueStack = append(insertValueStack, "NULL")
		} else {
			bindParams = append(bindParams, v)
			insertKeyStack = append(insertKeyStack, k)
			insertValueStack = append(insertValueStack, "?")
		}
	}
	insertKeyStr := strings.Join(insertKeyStack, ",")
	insertValueStr := strings.Join(insertValueStack, ",")

	sql := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", table, insertKeyStr, insertValueStr)

	result, err = dbOrm.Parent.Exec(context, sql, bindParams...)
	return
}

// 更新
func (dbOrm *DBOrm) Update(context context.Context, table string, updateMap map[string]interface{}, whereConds map[string]interface{}, dbOptions *DBOptions) (result int64, err error) {
	if dbOptions == nil {
		dbOptions = DEFAULT_DBOPTION
	}
	if updateMap == nil || len(table) <= 0 {
		err = ERR_INVALID_PARAM
		return
	}
	if whereConds == nil || len(whereConds) <= 0 {
		err = ERR_INVALID_PARAM
		return
	}
	bindParams := []interface{}{}
	var updateStr, whereCondsStr string
	updateStr, bindParams = dbOrm.buildUpdateCond(bindParams, updateMap)
	whereCondsStr, bindParams = dbOrm.buildWhereCond(bindParams, whereConds)
	sql := fmt.Sprintf("UPDATE %s SET %s %s", table, updateStr, whereCondsStr)
	result, err = dbOrm.Parent.Exec(context, sql, bindParams...)
	return
}

// 查询
func (dbOrm *DBOrm) Select(context context.Context, result interface{}, table string, whereConds map[string]interface{}, otherConds *DBOtherConds, dbOptions *DBOptions) (err error) {
	if dbOptions == nil {
		dbOptions = DEFAULT_DBOPTION
	}

	if otherConds == nil {
		otherConds = DEFAULT_OTHER_CONDS
	}
	if len(otherConds.Fields) <= 0 {
		otherConds.Fields = "*"
	}

	bindParams := []interface{}{}
	var whereCondsStr, groupByStr, orderByStr, limitStr string

	whereCondsStr, bindParams = dbOrm.buildWhereCond(bindParams, whereConds)
	if len(otherConds.GroupBy) > 0 {
		groupByStr = "GROUP BY " + otherConds.GroupBy
		groupByStr = dbOrm.Escape(groupByStr)
	}
	if len(otherConds.OrderBy) > 0 {
		orderByStr = "ORDER BY " + otherConds.OrderBy
		orderByStr = dbOrm.Escape(orderByStr)
	}
	if len(otherConds.Limit) > 0 {
		limit := strings.TrimSpace(otherConds.Limit)
		limitArray := strings.Split(limit, ",")
		n := len(limitArray)
		if n == 2 {
			offset, _ := strconv.Atoi(limitArray[0])
			limitInfo, _ := strconv.Atoi(limitArray[1])
			bindParams = append(bindParams, offset, limitInfo)
			limitStr = "LIMIT ?,?"
		} else if n == 1 {
			limitInfo, _ := strconv.Atoi(limitArray[0])
			bindParams = append(bindParams, limitInfo)
			limitStr = "LIMIT ?"
		}
	}
	sql := fmt.Sprintf("SELECT %s FROM %s %s %s %s %s", otherConds.Fields, table, whereCondsStr, groupByStr, orderByStr, limitStr)
	return dbOrm.Parent.QueryWithOptions(context, dbOptions, result, sql, bindParams...)
}
